Modeling hierarchical data with multiple levels of nesting (e.g., individuals within groups within regions).
General Principles
To model hierarchies where groups are themselves nested within larger units (e.g., students within classes, which are within schools), we use Nested Varying Effects. This allows the model to share information across different levels of the hierarchy, improving estimates through multilevel pooling.
In a nested model, a parameter at one level (e.g., a group-level intercept) becomes the mean for the distribution of parameters at a lower level (e.g., individual-level intercepts).
This structure is often represented using multiple indices: j for the group and k(j) for the super-group (e.g., region) that group j belongs to.
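The mapping k(j) can be recovered from observation-level index arrays. The following is a minimal NumPy sketch with made-up indices (the arrays and `n_groups` are hypothetical, not taken from the bundled dataset); it mirrors the scatter step used in the hand-written example.

```python
import numpy as np

# Hypothetical observation-level indices (one entry per individual).
group_id = np.array([0, 0, 1, 2, 2, 3, 3, 3])
region_id = np.array([0, 0, 0, 1, 1, 1, 1, 1])
n_groups = 4

# k(j): region of each group, filled by scattering region_id into group slots.
group_to_region = np.zeros(n_groups, dtype=np.int64)
group_to_region[group_id] = region_id

# Nesting check: every observation of a group must agree on its region.
assert np.array_equal(group_to_region[group_id], region_id)
print(group_to_region)  # [0 0 1 1]
```

If a group appeared in two regions, the final check would fail, which signals that the data are not strictly nested.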
To capture the correlation between multiple varying effects (e.g., intercept and slope) at any level of the hierarchy, we use a Multivariate Normal (MVN) distribution.
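As an illustration, the MVN covariance can be assembled from a vector of standard deviations and a correlation matrix. The numbers below are arbitrary assumptions chosen for readability, not estimates from any model.

```python
import numpy as np

sd = np.array([2.0, 0.5])              # sd of intercepts, sd of slopes
R = np.array([[1.0, -0.3],
              [-0.3, 1.0]])            # correlation between intercept and slope

# Covariance = diag(sd) @ R @ diag(sd); same construction as cov_reg / cov_grp.
Sigma = np.diag(sd) @ R @ np.diag(sd)

# Draw correlated (intercept, slope) pairs for five hypothetical groups.
rng = np.random.default_rng(0)
effects = rng.multivariate_normal(mean=[5.0, -1.0], cov=Sigma, size=5)
print(Sigma)
print(effects.shape)  # (5, 2)
```

The off-diagonal entry of `Sigma` is the product of the two standard deviations and the correlation, so a negative correlation means groups with higher intercepts tend to have lower slopes.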
Non-centered parameterization is highly recommended for nested models to avoid "funnel" geometries that can hinder MCMC sampling.
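The idea behind the non-centered form can be checked numerically: standard normal draws are scaled by the standard deviations and the Cholesky factor of the correlation matrix, then shifted by the mean. This is a NumPy sketch under assumed values, not the library's internal implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([5.0, -1.0])             # assumed mean (intercept, slope)
sd = np.array([2.0, 0.5])              # assumed standard deviations
R = np.array([[1.0, -0.3],
              [-0.3, 1.0]])            # assumed correlation matrix

L = np.linalg.cholesky(R)              # lower-triangular, L @ L.T == R
scale = np.diag(sd) @ L                # maps standard normals to the target scale

# Non-centered draw: sample z ~ Normal(0, I), then transform deterministically.
z = rng.standard_normal((100_000, 2))
effects = mu + z @ scale.T

# The implied covariance equals diag(sd) @ R @ diag(sd).
Sigma = np.diag(sd) @ R @ np.diag(sd)
print(np.round(np.cov(effects.T), 2))
```

The sampler then explores the uncorrelated `z` values, whose geometry does not depend on the scale parameters; this is what removes the funnel.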
Example
Below is an example of a nested model with both varying intercepts and varying slopes. We model the outcome y for individuals in groups nested within regions. The relationship between x and y varies at both levels.
from BI import bi
import jax.numpy as jnp
import numpy as np

# Setup device -----------------------------------------------
m = bi(platform='cpu')

# Load data: column args are mapped automatically from the DataFrame.
# N_groups and N_regions are dataset-specific; provide them as defaults.
data_path = m.load.sim_nested_effects(only_path=True)
m.data(data_path)

# Define model ------------------------------------------------
def model_nested(y, x, group_id, region_id, N_groups=20, N_regions=5):
    sigma = m.dist.exponential(1, name='sigma')

    # 1. Region level
    mu_global = jnp.stack([
        m.dist.normal(5, 2, name='global_intercept'),
        m.dist.normal(-1, 1, name='global_beta'),
    ])
    sigma_reg = m.dist.exponential(1, shape=(2,), name='sigma_region')
    corr_reg = m.dist.lkj(2, 2, name='corr_region')
    cov_reg = jnp.diag(sigma_reg) @ corr_reg @ jnp.diag(sigma_reg)
    region_effects = m.dist.multivariate_normal(
        mu_global, cov_reg, shape=(N_regions,), name='region_effects'
    )

    # 2. Group level: parent mapping via JAX scatter (traceable under pmap)
    group_to_region = jnp.zeros(N_groups, dtype=jnp.int32).at[group_id].set(region_id)
    sigma_grp = m.dist.exponential(1, shape=(2,), name='sigma_group')
    corr_grp = m.dist.lkj(2, 2, name='corr_group')
    cov_grp = jnp.diag(sigma_grp) @ corr_grp @ jnp.diag(sigma_grp)
    group_effects = m.dist.multivariate_normal(
        region_effects[group_to_region], cov_grp, name='group_effects'
    )

    # Likelihood: each observation uses its group's intercept and slope.
    mu_est = group_effects[group_id, 0] + group_effects[group_id, 1] * x
    m.dist.normal(mu_est, sigma, obs=y)

# Run sampler ------------------------------------------------
m.fit(model_nested, num_samples=1000, num_warmup=500, num_chains=1)
m.summary()
/home/sosa/.local/lib/python3.10/site-packages/jax/_src/ops/scatter.py:108: FutureWarning:
scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.
/home/sosa/work/BI/BI/Diagnostic/jax_diagnostics.py:214: RuntimeWarning:
invalid value encountered in scalar divide
                   mean    sd  hdi_5.5%  hdi_94.5%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
L_group[0, 0]      1.00  0.00      1.00       1.00       0.00     0.00   3000.00   3000.00    NaN
L_group[0, 1]      0.00  0.00      0.00       0.00       0.00     0.00   3000.00   3000.00    NaN
L_group[1, 0]      0.42  0.23      0.03       0.75       0.01     0.01    584.84    675.29    NaN
L_group[1, 1]      0.87  0.11      0.72       1.00       0.00     0.00    581.67    675.29    NaN
L_region[0, 0]     1.00  0.00      1.00       1.00       0.00     0.00   3000.00   3000.00    NaN
...                 ...   ...       ...        ...        ...      ...       ...       ...    ...
z_region[1, 0]    -0.71  0.66     -1.71       0.38       0.02     0.02    831.50    720.19    NaN
z_region[1, 1]     1.05  0.65      0.09       2.11       0.02     0.02    934.02    733.74    NaN
z_region[1, 2]    -0.58  0.66     -1.62       0.42       0.03     0.02    661.19    724.94    NaN
z_region[1, 3]     0.39  0.79     -0.75       1.73       0.02     0.02   1060.03    744.14    NaN
z_region[1, 4]    -0.02  0.70     -1.10       1.08       0.02     0.02    884.18    604.35    NaN

65 rows × 9 columns
Code
from BI import bi
import jax.numpy as jnp

# Setup device -----------------------------------------------
m = bi(platform='cpu')

# Load data: column args are mapped automatically from the DataFrame.
# N_groups and N_regions are dataset-specific; provide them as defaults.
data_path = m.load.sim_nested_effects(only_path=True)
m.data(data_path)

# group_ids: one obs-level index array per level, top to bottom.
# Parent structure is derived automatically from these arrays.
def model_nested_builtin(y, x, group_id, region_id, N_groups=20, N_regions=5):
    sigma = m.dist.exponential(1, name='sigma')
    a_g_est, b_g_est = m.effects.nested_varying_effects(
        N_vars=2,
        names=["region", "group"],
        N_groups=[N_regions, N_groups],
        group_ids=[region_id, group_id],
        centered=False,
    )
    mu_est = a_g_est + b_g_est * x
    m.dist.normal(mu_est, sigma, obs=y)

m.fit(model_nested_builtin, num_samples=1000, num_warmup=500, num_chains=1)
m.summary()
from BI import bi
import jax.numpy as jnp

# Setup device -----------------------------------------------
m = bi(platform='cpu')

# Load data: column args are mapped automatically from the DataFrame.
# N_groups and N_regions are dataset-specific; provide them as defaults.
data_path = m.load.sim_nested_effects(only_path=True)
m.data(data_path)

def model_nested_builtin_centered(y, x, group_id, region_id, N_groups=20, N_regions=5):
    sigma = m.dist.exponential(1, name='sigma')
    a_g_est, b_g_est = m.effects.nested_varying_effects(
        N_vars=2,
        names=["region", "group"],
        N_groups=[N_regions, N_groups],
        group_ids=[region_id, group_id],
        centered=True,
    )
    mu_est = a_g_est + b_g_est * x
    m.dist.normal(mu_est, sigma, obs=y)

m.fit(model_nested_builtin_centered, num_samples=1000, num_warmup=500, num_chains=1)
m.summary()
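In centered form, the nested model from the example can be sketched as follows. This is a reconstruction from the hand-written example code and the symbol definitions in this section; the subscripts j[i] (group of observation i) and k(j) (region of group j) are assumed notation.

```latex
\begin{aligned}
y_i &\sim \text{Normal}(\mu_i, \sigma) \\
\mu_i &= \alpha_{j[i]} + \beta_{j[i]} \, x_i \\
\begin{pmatrix} \alpha_j \\ \beta_j \end{pmatrix}
  &\sim \text{MVN}\left(
    \begin{pmatrix} \gamma_{\alpha, k(j)} \\ \gamma_{\beta, k(j)} \end{pmatrix},
    \Sigma_{group} \right) \\
\begin{pmatrix} \gamma_{\alpha, k} \\ \gamma_{\beta, k} \end{pmatrix}
  &\sim \text{MVN}\left(
    \begin{pmatrix} \bar{\gamma}_\alpha \\ \bar{\gamma}_\beta \end{pmatrix},
    \Sigma_{region} \right)
\end{aligned}
```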
Where:
- \Sigma = \mathbf{S} \mathbf{R} \mathbf{S} (diagonal matrix of standard deviations \mathbf{S} and correlation matrix \mathbf{R}).
- \bar{\gamma}_\alpha, \bar{\gamma}_\beta are the global average intercept and slope.
- \Sigma_{region}, \Sigma_{group} are the covariance matrices for each hierarchical level.
Non-Centered Parameterization
To avoid funnel geometries, the non-centered formulation instead uses Cholesky factors L_{region} and L_{group} alongside standardized, non-correlated normal variables:
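A sketch of that formulation, assuming L_{region} and L_{group} are the Cholesky factors of the level-specific correlation matrices (so that \mathbf{L} \mathbf{L}^\top = \mathbf{R}) and that \mathbf{S}_{region}, \mathbf{S}_{group} are the corresponding diagonal matrices of standard deviations; the vector notation \gamma_k = (\gamma_{\alpha,k}, \gamma_{\beta,k})^\top is introduced here for brevity:

```latex
\begin{aligned}
\mathbf{z}_{region,k} &\sim \text{Normal}(\mathbf{0}, \mathbf{I}) \\
\gamma_k &= \bar{\gamma} + \mathbf{S}_{region} \, L_{region} \, \mathbf{z}_{region,k} \\
\mathbf{z}_{group,j} &\sim \text{Normal}(\mathbf{0}, \mathbf{I}) \\
\begin{pmatrix} \alpha_j \\ \beta_j \end{pmatrix}
  &= \gamma_{k(j)} + \mathbf{S}_{group} \, L_{group} \, \mathbf{z}_{group,j}
\end{aligned}
```

Because (\mathbf{S} L)(\mathbf{S} L)^\top = \mathbf{S} \mathbf{R} \mathbf{S} = \Sigma, each transformed vector has the same distribution as in the centered form.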
By nesting the MVN prior, we allow the model to learn how the correlation between intercepts and slopes persists across different levels of the hierarchy.
The non-centered parameterization (used by default in the m.effects.nested_varying_effects built-in function) is critical for convergence in these complex models.
This approach can be generalized to N levels and M variables by increasing the dimensionality of the vectors and matrices.
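To make the generalization concrete, here is a NumPy simulation sketch that draws non-centered effects level by level for any number of levels. The function `simulate_nested_effects`, the extra "country" level, and all numeric values are hypothetical illustrations and are not part of the BI API.

```python
import numpy as np

def simulate_nested_effects(global_mean, parents, sds, corrs, rng):
    """Draw non-centered effects level by level, top to bottom.

    parents[l][j] is the index of unit j's parent in the level above
    (all zeros for the top level, whose only parent is the global mean).
    """
    level = np.asarray(global_mean, dtype=float)[None, :]   # shape (1, M)
    out = []
    for parent, sd, corr in zip(parents, sds, corrs):
        scale = np.diag(sd) @ np.linalg.cholesky(corr)      # S @ L for this level
        z = rng.standard_normal((len(parent), level.shape[1]))
        level = level[parent] + z @ scale.T                 # parent mean + offset
        out.append(level)
    return out

rng = np.random.default_rng(1)
parents = [
    np.zeros(2, dtype=int),          # 2 countries under the global mean
    np.array([0, 0, 0, 1, 1]),       # 5 regions within countries
    np.repeat(np.arange(5), 4),      # 20 groups, 4 per region
]
sds = [np.array([1.0, 0.5])] * 3
corrs = [np.array([[1.0, 0.3], [0.3, 1.0]])] * 3

countries, regions, groups = simulate_nested_effects(
    [5.0, -1.0], parents, sds, corrs, rng
)
print(countries.shape, regions.shape, groups.shape)  # (2, 2) (5, 2) (20, 2)
```

Adding a level means appending one parent-index array, one vector of standard deviations, and one correlation matrix; adding a variable means increasing the length M of the mean vector and the size of each matrix.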